using Tables, Plots, CSV, DataFrames, MAT, Distributions, LinearAlgebra, SparseArrays, JLD2, HDF5, Random, StatsFuns
include("setup.jl")
include("getBounds.jl")
include("mcmcfunctions.jl")
include("counterfactualFunctionsNew.jl")
include("modelfit.jl")
include("tableHelperFunctions.jl")

# ---------------------------------------------------------------------------
# Run switches
# ---------------------------------------------------------------------------
# getStartvals: if true, run on subsample and save to "chain1", otherwise save to "chain2".
const getStartvals = false
# estimate: if true the sampler is run below; if false previously saved draws are read back.
const estimate = true
# Tables and counterfactual simulations are only produced when getStartvals is false.
const makeTables = !getStartvals
const simulateCounterfactuals = !getStartvals
# runMinimalCounterfactuals: if true, skip the scale-alpha and drop-selectivity-decile
# counterfactuals. Not read in this file (presumably by the included scripts -- confirm).
const runMinimalCounterfactuals = false

# ---------------------------------------------------------------------------
# Machine-specific locations
# ---------------------------------------------------------------------------
const datapath = "G:/My Drive/4_clean"
const startvalpath = "G:/My Drive/4_clean/from_Julia/chain1"
const figurepath = "C:/Users/akapor/Documents/GitHub/Platform-Externalities/5_paper/Round 2/Paper/figures/mcmc"
const tablepath = "C:/Users/akapor/Documents/GitHub/Platform-Externalities/5_paper/Round 2/Paper/tables"

# ---------------------------------------------------------------------------
# Seed, model specification and sample years
# ---------------------------------------------------------------------------
# The global RNG is seeded before the model object is built; keep this order.
Random.seed!(12345)
const model = ProgramFE()
const years = 2010:2012
# File to load theta from.
const fname_theta = string(startvalpath, "/params_20000.mat")
# Per-year files holding the saved latent-variable state.
const fname_lv = Dict(yr => string(startvalpath, "/latentvariables$(yr)_20000.mat") for yr in years)

# Settings that differ between the two run modes selected by `getStartvals`:
#   true  -> 10% subsample, nothing loaded from disk, output written to chain1
#   false -> full sample, theta loaded from `fname_theta`, output written to chain2

# Fraction of the sample to keep; set to 1.0 for full sample.
const _sample_fraction = getStartvals ? 0.10 : 1.0
const loadFromMatlab = true
# "parameters", could come from a subset
const loadSavedTheta = !getStartvals
# rest of state: utilities, etc.
const loadSavedLV = false
const outputpath = getStartvals ? "G:/My Drive/4_clean/from_Julia/chain1" :
                                  "G:/My Drive/4_clean/from_Julia/chain2"

# Iteration range and the iterations at which parameters / latent variables are saved.
# These are intentionally left as non-const globals, as in the original.
iter = getStartvals ? (0:20000) : (0:7500)
save_params = getStartvals ? (10000:125:20000) : (2500:50:7500)
save_lv = getStartvals ? (10000:250:20000) : (2500:200:7500)
# -1:-1 lies outside `iter` in both modes; presumably this disables counterfactual saves
# during estimation -- confirm against MCMC!.
# Earlier alternative: _sample_fraction==1 ? (0:250:2500) : (-1:-1)
save_counterfactuals = -1:-1

# Build the per-year estimation state from the MATLAB exports and price CSVs.
# Reloading a previously saved JLD2 copy is not supported:
#   JLD2.@load(datapath*"/"*datafilename,dataset,counterfactuals)
loadFromMatlab || error("file was too big to save so I can't load it")

dataset = Dict{Int,Any}()          # year => (constants, lv, temps)
counterfactuals = Dict{Int,Any}()  # year => (cf_constants, cf_temps); filled for 2012 only
for year in years
    println("loading prices ... ")
    # Price matrix read from a header-less CSV and divided by 1e6.
    Pij = CSV.read("$datapath/P$year.csv",Tables.matrix, header=false) ./ 1e6
    print("loading data from .mat file: year $year ... ")
    MCMC = matread(datapath*"/export_to_julia_$(year).mat")["MCMC$year"]
    println("done.")
    MCMC["Pij"] = Pij
    println("setup: $(100*_sample_fraction)% sample")
    # One uniform draw per row, in order; rows whose draw falls below the fraction are kept.
    mysample = findall([rand() < _sample_fraction for _ in 1:Int(MCMC["I"])])
    constants = setupData(MCMC,year,mysample,model)
    cleanEnrollment!(constants)
    lv, temps = getLatentVars(constants)
    startvals!(constants,lv,temps)
    checkValidity(constants,lv)
    dataset[year] = (constants,lv,temps)

    # Counterfactual inputs are only assembled for 2012.
    year == 2012 || continue

    cf_constants = let
        # Priorities are zeroed wherever the sampled row is not eligible to apply.
        eligibleMask = Array{Bool,2}(Bool.(MCMC["EligibleToApply"])[mysample,:])
        maskedPriorities = Array{Int,2}(MCMC["Priorities"])[mysample,:] .* eligibleMask
        nrows = size(maskedPriorities, 1)
        # Capacities are scaled by the sample fraction and rounded up.
        seats = vec(ceil.(Int, MCMC["seats"] .* _sample_fraction))
        sobrecupo = vec(ceil.(Int, MCMC["sobrecupo"] .* _sample_fraction))
        institutionID = vec(Int.(MCMC["TypeIndex"][:,3]))
        G8 = institutionID .> 25
        # Sanity check: exactly 8 distinct institution IDs exceed 25.
        @assert 8 == length(unique(institutionID[G8]))
        # A trailing column of ones is appended to the masked priorities.
        (priorityEligible = hcat(maskedPriorities, ones(Int, nrows)),
         seats = seats, sobrecupo = sobrecupo,
         institutionID = institutionID, G8 = G8, Pij = Pij)
    end

    cf_temps = let
        # Work arrays are sized from lv.U, with one extra slot along its first dimension.
        nrowU, ncolU = size(lv.U)
        Jplus1 = nrowU + 1
        alphadraws = rand(Jplus1, ncolU)
        alphadraws[Jplus1,:] .= 1 #2nd oo is never unavailable
        (proposalcount = zeros(Int, Jplus1),
         point_at = zeros(Int, ncolU),
         enroll = zeros(Int, Jplus1),
         U_pointed = zeros(ncolU),
         Ufull = zeros(Jplus1, ncolU),
         programPrefs = zeros(Int, ncolU, Jplus1),
         programROL = zeros(Int, ncolU, Jplus1),
         alphadraws = alphadraws,
         Xoutcome = similar(temps.Xoutcome),
        )
    end

    counterfactuals[year] = (cf_constants,cf_temps)
end
#println("saving data")
#JLD2.@save(datapath*"/"*datafilename,dataset,counterfactuals)

# Replace each year's constants with the augmented version returned by
# makeAdditionalConstants; lv and temps are carried over unchanged.
for year in keys(dataset)
    (cc,lv,temps) = dataset[year]
    dataset[year] = (makeAdditionalConstants(cc,lv),lv,temps)
end
# One Mersenne Twister per thread, seeded 1234*n for n = 1..nthreads().
# Not used in this file; presumably consumed by the included sampler code -- confirm.
const myRNG = [Random.MersenneTwister(1234*n) for n=1:Threads.nthreads()]
# Re-wrap the Dict{Int,Any} with a concrete value type taken from the first sample year.
# Using first(years) instead of a literal 2010 keeps this working when `years` is edited.
# As before, every year's entry must be convertible to that type.
dataset = Dict{Int64,typeof(dataset[first(years)])}(dataset)

# Parameter vector: either read back from `fname_theta` or freshly initialised.
theta = if loadSavedTheta
    let saved = matread(fname_theta)
        # Field names follow the iteration order of the Dict returned by matread.
        fields = Tuple(Symbol(k) for k in keys(saved))
        NamedTuple{fields}(map(k -> saved[string(k)], fields))
    end
else
    reinitialize!(dataset,model)
end

# Optionally overwrite each year's latent variables, in place, with a saved state.
# NOTE(review): the paths in `fname_lv` end in ".mat" but are read with JLD2 here --
# confirm the files were written by JLD2 and contain a variable named `lv_out`.
if loadSavedLV
    for yr in keys(dataset)
        println("loading state: year $yr")
        JLD2.@load fname_lv[yr] lv_out
        # Copy field by field into the existing arrays so their identity is preserved.
        for field in keys(dataset[yr][2])
            dataset[yr][2][field] .= lv_out[field]
        end
    end
end

# Prior distributions, dimensioned from the column counts of the 2011 design matrices.
priors = let
    cc = dataset[2011][1]
    ncols(X) = size(X, 2)
    # Zero-mean normal with covariance given as the uniform scaling 10*I.
    scaledNormal(n) = MvNormal(zeros(n), 10*I)
    # Zero-mean normal with covariance given as an explicit dense 10*identity matrix.
    denseNormal(n) = MvNormal(zeros(n), 10*diagm(ones(n)))

    n_matchterms = ncols(cc.Xij[1])
    n_fixed = ncols(cc.Xfixed)
    n_rc = ncols(cc.Xrc)
    n_oo = ncols(cc.Xoo)
    n_alpha = ncols(cc.Xfriction_full[1])

    (
    ij = scaledNormal(n_matchterms),
    fixed = scaledNormal(n_fixed),
    oo0 = denseNormal(n_oo),
    oo1 = denseNormal(n_oo),
    alpha = denseNormal(n_alpha),
    VCVrc = InverseWishart(n_rc+1, diagm(10 .* ones(n_rc))),
    var_oo0 = InverseGamma(100,100),
    var_oo1 = InverseGamma(100,100),
    # Dimension is 1 + n_oo + n_matchterms + n_fixed.
    outcome = scaledNormal(1+n_oo+n_matchterms+n_fixed),
    )
end

if estimate
    # Warm-up before the main run. The last argument is 5 when theta was freshly
    # initialised and 5000 when theta was loaded but the latent state was not
    # (presumably a number of update sweeps -- confirm in justUpdateIndividuals!).
    # Nothing is run when both theta and the latent state were loaded.
    if !loadSavedTheta
        justUpdateIndividuals!(theta,dataset,priors,5)
    elseif !loadSavedLV
        justUpdateIndividuals!(theta,dataset,priors,5000)
    end
    # Empty container for parameter draws, typed to match theta.
    history = typeof(theta)[]
end

# Either run the sampler, or rebuild `history` from the parameter files of a previous run.
if estimate #run it!!!
    println("estimating")
    # Progress report passed to MCMC!; prints a few summary quantities and returns false.
    # NOTE(review): the meaning of the returned `false` is defined by MCMC! -- confirm.
    mycallback = (theta,dataset) -> begin
        println("varU: $(var(dataset[2011][2].U))")
        println("var of program FE: $(var(theta.beta_fixed[2:end,:],dims=1))")
        println("var of oo's: $(theta.sigsqOO0), $(theta.sigsqOO1)")
        println("price coef: $(theta.beta_ij[1,:])")
        #println("price coef (non-teaching): $(theta.beta_ij[2,:])")
        println("coefs on utility: $(theta.betaOutcome[1,:])")
        println("alpha: $(theta.alpha)")
        # for year in keys(dataset)
        #     (cc,lv,temps) = dataset[year]
        #     u0m = mean(log.(view(lv.U0,1,:)) .- view(temps.mu0,1,:))
        #     u1m = mean(log.(view(lv.U0,2,:)) .- view(temps.mu0,2,:))
        #     println("outside-option fit, year $year: $([u0m,u1m])")
        # end
        false
    end
    MCMC!(theta,dataset,counterfactuals,priors,iter,save_params=save_params, save_lv=save_lv,
        save_counterfactuals=save_counterfactuals, history=history, callback=mycallback, model=model)
else
    history = typeof(theta)[]
    for t in save_params
        # matread opens, reads and closes the file in one call and releases the handle
        # even if reading fails (the manual matopen/read/close sequence leaked it on error).
        _theta_t = matread(outputpath*"/params_$t.mat")
        # Rebuild the draw with the same field names and order as the current theta.
        _keys = Tuple(keys(theta))
        _vals = ([_theta_t[string(k)] for k in _keys]...,)
        theta_t = NamedTuple{Symbol.(_keys)}(_vals)
        push!(history,theta_t)
    end
end
# Parameter-estimate tables.
if makeTables
    include("describeParameterEstimates.jl")
end

# Counterfactual simulations followed by their summaries.
if simulateCounterfactuals
    include("runCounterfactuals.jl")
    include("describeCounterfactualSimulations.jl")
end

#=
Note: once the steps above are done, simulateOneYear.jl needs to be run to do the
model-fit exercises. Per the original note, that file can also be run separately.
=#
include("simulateOneYear.jl")
include("eventStudyWithinModel.jl")  #appendix table